# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sympy as sm
import numpy as np
from hysop.tools.htypes import check_instance, first_not_None
from hysop.symbolic import time_symbol, space_symbols
from hysop.symbolic import AppliedUndef, UndefinedFunction
from hysop.symbolic.base import TensorBase, ScalarBaseTag
class FunctionBase(ScalarBaseTag):
    """Base for symbolic functions.

    Accepts an optional ``fn`` keyword argument. The callable (or None) is
    stored on the newly created object as ``fn`` and is appended to the
    sympy hashable content of the object.
    """

    def __new__(cls, *args, **kwds):
        # ``fn`` is optional and must not reach the parent constructor.
        fn = kwds.pop("fn", None)
        obj = super().__new__(cls, *args, **kwds)
        # When ``obj`` is a class (e.g. a SymbolicFunction), storing the
        # callable as a staticmethod keeps it from being bound as a method
        # when it is looked up through instances of that class.
        obj.fn = staticmethod(fn)
        return obj

    def __init__(self, *args, **kwds):
        # ``fn`` was already consumed by __new__; strip it here as well so it
        # is not forwarded to the parent initializer. Use a default so that,
        # like __new__, a missing ``fn`` is accepted instead of raising
        # KeyError.
        kwds.pop("fn", None)
        super().__init__(*args, **kwds)

    def _hashable_content(self):
        """See sympy.core.basic.Basic._hashable_content()

        The attached ``fn`` is appended to the parent's hashable content.
        """
        hc = super()._hashable_content()
        hc += (self.fn,)
        return hc
class SymbolicFunction(FunctionBase, UndefinedFunction):
    """Unapplied symbolic scalar function.

    ``fn`` is an optional callable handed to FunctionBase, which stores it on
    the created function class. When ``bases`` is not given, applying the
    function produces AppliedSymbolicFunction instances.
    """

    def __new__(cls, name, fn=None, bases=None, **kwds):
        if bases is None:
            # Default applied-function base defined in this module.
            bases = (AppliedSymbolicFunction,)
        return super().__new__(cls, bases=bases, name=name, fn=fn, **kwds)

    def __init__(self, name, fn=None, bases=None, **kwds):
        # ``bases`` is forwarded exactly as received (possibly None).
        super().__init__(bases=bases, name=name, fn=fn, **kwds)
class AppliedSymbolicFunction(AppliedUndef):
    """Applied symbolic scalar function.

    Instances are obtained by calling a SymbolicFunction on arguments. If the
    creating SymbolicFunction carries an implementation ``fn``, it can be
    substituted into the expression with :meth:`freplace` (or by calling the
    instance with no arguments).
    """

    def __new__(cls, *args, **kwds):
        return super().__new__(cls, *args, **kwds)

    def __init__(self, *args, **kwds):
        # Positional arguments were already handed to __new__; only keyword
        # arguments are forwarded to the parent initializer.
        super().__init__(**kwds)

    def freplace(self):
        """Return ``fn(*self.args)`` when an implementation is attached.

        Returns ``self`` unchanged when no implementation is available. Only
        this node is replaced; applied functions nested inside ``self.args``
        are left as they are.
        """
        # ``fn`` is attached to the class by FunctionBase when it is created
        # through SymbolicFunction. Fall back to None so that classes derived
        # from AppliedSymbolicFunction by other means are returned unchanged
        # instead of raising AttributeError.
        fn = getattr(self, "fn", None)
        if fn is None:
            return self
        return fn(*self.args)

    def __call__(self):
        """Shorthand for :meth:`freplace`."""
        return self.freplace()
class SymbolicFunctionTensor(TensorBase):
    """Tensor whose entries are symbolic functions.

    ``scalar_cls`` defaults to SymbolicFunction. ``fn`` is added to
    ``scalar_kwds`` under the key ``"fn"`` unless ``scalar_kwds`` already
    provides one. Both are forwarded to TensorBase, which presumably uses
    them to build each entry (TODO confirm against TensorBase).
    """

    def __new__(
        cls,
        shape,
        name=None,
        fn=None,
        init=None,
        scalar_cls=None,
        scalar_kwds=None,
        **kwds,
    ):
        scalar_cls = first_not_None(scalar_cls, SymbolicFunction)
        # Copy the caller's mapping before inserting "fn". Mutating it in place
        # would leak this tensor's fn into any later tensor built with the same
        # dict, and setdefault would then keep the stale fn instead of the new
        # one.
        scalar_kwds = dict(first_not_None(scalar_kwds, {}))
        scalar_kwds.setdefault("fn", fn)
        return super().__new__(
            cls,
            name=name,
            shape=shape,
            init=init,
            scalar_cls=scalar_cls,
            scalar_kwds=scalar_kwds,
            **kwds,
        )

    def __init__(
        self,
        shape,
        name=None,
        fn=None,
        init=None,
        scalar_cls=None,
        scalar_kwds=None,
        **kwds,
    ):
        # shape and init are deliberately forwarded as None, presumably
        # because TensorBase.__new__ already consumed them (verify).
        super().__init__(
            name=name,
            shape=None,
            init=None,
            scalar_cls=scalar_cls,
            scalar_kwds=scalar_kwds,
            **kwds,
        )

    def __call__(self, *args, **kwds):
        """Apply every entry to the same ``args``/``kwds``.

        Delegates to TensorBase.elementwise_fn, assumed to return a new
        tensor of the per-entry results.
        """
        return self.elementwise_fn(lambda x: x(*args, **kwds))

    def freplace(self):
        """Call ``freplace()`` on every entry (via elementwise_fn)."""
        return self.elementwise_fn(lambda x: x.freplace())
if __name__ == "__main__":
    # Demonstration of scalar and tensor symbolic functions.

    def fn(x0, x1):
        return x0 - x1

    # Unapplied scalar functions, with and without an implementation.
    f = SymbolicFunction("f", fn=sm.cos)
    g = SymbolicFunction("g", fn=lambda *x: sm.sin(np.prod(x)))
    h = SymbolicFunction("h", fn=lambda *x: sm.tan(np.sum(x)))
    i = SymbolicFunction("i", fn=fn)
    j = SymbolicFunction("j", fn=None)

    # The same functions applied to symbolic arguments.
    F = f(time_symbol)
    G = g(*space_symbols)
    H = h(5, *space_symbols)
    I = i(5, space_symbols[1])
    J = j(time_symbol, space_symbols[0])

    # Tensors of symbolic functions and their applied counterparts.
    a = SymbolicFunctionTensor(name="a", shape=(2, 2))
    b = SymbolicFunctionTensor(name="b", shape=(4,), fn=sm.cos)
    A = a(time_symbol)
    B = b(space_symbols[0])

    unapplied = (f, g, h, i, j)
    applied = (F, G, H, I, J)

    for item in unapplied:
        print(item)
    print()
    print(type(f).__mro__[:2])
    print()
    for item in applied:
        print(item)
    print()
    print(type(F).__mro__[:4])
    print()
    for item in applied:
        print(item())
    print()
    for item in (a, b):
        print(item)
    print()
    for item in (A, B):
        print(item)
    print()
    for item in (A, B):
        print(item.freplace())